import os
import json
from tqdm import tqdm
import torch
import open_clip
import random
import numpy as np

def load_landmark_data(dataset, split):
    """Read the landmark annotations for one dataset split.

    The file is looked up, relative to the current working directory, at
    ``../HOV-SG/train_data_prepare/<dataset>_landmarks/<split>_objs.json``.

    Args:
        dataset: Dataset name used as the prefix of the landmarks directory.
        split: Split name used as the prefix of the JSON file name.

    Returns:
        The decoded JSON content of the annotation file.
    """
    json_path = f"../HOV-SG/train_data_prepare/{dataset}_landmarks/{split}_objs.json"
    with open(json_path, 'r') as handle:
        return json.load(handle)

def load_all_vp_feats(scan, device=None):
    """Load the precomputed viewpoint features of one scan as a single tensor.

    Reads ``../ce_views/vp_feats/<scan>.json`` (relative to the current
    working directory), which maps viewpoint ids to nested lists of numbers.

    Args:
        scan: Scan identifier used as the JSON file name.
        device: Target device for the returned tensor. When ``None`` the
            tensor goes to CUDA if it is available, otherwise it stays on
            the CPU.

    Returns:
        A ``(vps, all_vp_feats)`` pair: the viewpoint ids in file order and a
        tensor whose first dimension is aligned with ``vps``.

    Raises:
        RuntimeError: If the file contains no viewpoints (``torch.stack``
            rejects an empty list) or the entries have mismatched shapes.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    path = f"../ce_views/vp_feats/{scan}.json"
    with open(path, 'r') as f:
        vp_feats_dict = json.load(f)

    vps = list(vp_feats_dict)
    # Stack on the host and move the result once, instead of issuing one
    # host-to-device copy per viewpoint.
    # NOTE(review): downstream code indexes this as
    # (num_viewpoints, num_views, feat_dim) -- confirm against the feature files.
    all_vp_feats = torch.stack(
        [torch.tensor(vp_feats_dict[vp]) for vp in vps], dim=0
    ).to(device)
    return vps, all_vp_feats

def locate_room(target_name, room_types, room_type_feats, clip_model, tokenizer, vps, all_vp_feats, device):
    """Return the viewpoints whose views are classified as ``target_name``.

    Every view of every viewpoint is scored against the text features of all
    room types (softmax over ``100 * feature @ text``). A viewpoint is a
    candidate when at least one of its views has ``target_name`` as its
    top-1 room type. If fewer than three viewpoints qualify, the (up to)
    three viewpoints with the highest best-view score for the target are
    appended, skipping those already selected.

    Args:
        target_name: Room type to look for. If it is not in ``room_types``
            it is encoded with ``clip_model`` and appended as an extra class.
        room_types: List of known room type names.
        room_type_feats: Text features aligned with ``room_types``.
        clip_model: Model exposing ``encode_text``; only used for unknown
            target names.
        tokenizer: Callable turning a list of prompts into a token tensor;
            only used for unknown target names.
        vps: Viewpoint ids aligned with the first dimension of
            ``all_vp_feats``.
        all_vp_feats: Viewpoint features, indexed here as
            ``(num_viewpoints, num_views, feat_dim)``.
        device: Device type string passed to ``torch.autocast`` and used for
            the tokenized prompt.

    Returns:
        List of viewpoint ids without duplicates: direct matches first (in
        ``vps`` order), then fallback viewpoints by decreasing score.
    """
    if target_name in room_types:
        all_room_types = room_types
        text_features = room_type_feats
    else:
        all_room_types = room_types + [target_name]
        # Follow the caller-supplied device instead of assuming CUDA.
        target_text = tokenizer(["a photo of a " + target_name]).to(device)
        with torch.no_grad(), torch.autocast(device):
            target_feature = clip_model.encode_text(target_text)
        text_features = torch.cat([room_type_feats, target_feature], dim=0)

    with torch.no_grad(), torch.autocast(device):
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        logits = 100.0 * all_vp_feats @ text_features.T
        # Softmax in float32 regardless of the autocast matmul dtype so the
        # ranking below does not depend on reduced-precision kernels.
        similarity = logits.float().softmax(dim=-1)

    # squeeze(-1) drops the k=1 axis without hard-coding the number of views.
    top1_indices = similarity.topk(1, dim=-1).indices.squeeze(-1).cpu().numpy()
    target_idx = all_room_types.index(target_name)
    vp_array = np.array(vps)
    candidates = vp_array[np.any(top1_indices == target_idx, axis=1)].tolist()

    if len(candidates) < 3:
        best_view_scores = similarity[:, :, target_idx].max(dim=1).values
        # topk raises when k exceeds the number of viewpoints.
        k = min(3, len(vps))
        ranked = best_view_scores.topk(k).indices.cpu().numpy()
        # dict.fromkeys keeps first occurrences, so direct matches stay in
        # front and nothing is listed twice.
        candidates = list(dict.fromkeys(candidates + vp_array[ranked].tolist()))

    return candidates

def load_obj_feats(scan):
    """Load the per-object average features of one scan as a NumPy matrix.

    Reads ``../HOV-SG/openscene_gt_objs/results/<scan>.json`` (relative to
    the current working directory), a JSON list of objects that each carry
    an ``'avg_feature'`` entry.

    Args:
        scan: Scan identifier used as the JSON file name.

    Returns:
        A ``(obj_ids, obj_feats)`` pair: ``obj_ids`` is ``0..len(objs)-1`` as
        a NumPy array and ``obj_feats`` stacks the ``'avg_feature'`` entries
        row by row in the same order.

    Raises:
        RuntimeError: If the file lists no objects (``torch.stack`` rejects
            an empty list) or the features have mismatched shapes.
    """
    with open(f"../HOV-SG/openscene_gt_objs/results/{scan}.json", 'r') as f:
        objs = json.load(f)

    obj_ids = np.arange(len(objs))
    # The result is consumed as a NumPy array, so build it on the host; a
    # GPU round trip would only add copies and require CUDA.
    obj_feats = torch.stack(
        [torch.tensor(obj['avg_feature']) for obj in objs], dim=0
    ).numpy()
    return obj_ids, obj_feats

def locate_object(target_name, obj_ids, object_feats, obj_types, obj_type_feats, clip_model, tokenizer, scan, device):
    """Return the ids of objects that best match ``target_name``.

    Each object feature is scored against the normalized text features of
    all object types by dot product. An object is a candidate when
    ``target_name`` is its highest-scoring type. If fewer than three objects
    qualify, the (up to) three objects with the highest target score are
    appended, skipping those already selected.

    Args:
        target_name: Object type to look for. If it is not in ``obj_types``
            it is encoded with ``clip_model`` and appended as an extra class.
        obj_ids: NumPy array of object ids aligned with the rows of
            ``object_feats``.
        object_feats: NumPy matrix with one feature row per object.
        obj_types: List of known object type names.
        obj_type_feats: Text features aligned with ``obj_types``.
        clip_model: Model exposing ``encode_text``; only used for unknown
            target names.
        tokenizer: Callable turning a list of prompts into a token tensor;
            only used for unknown target names.
        scan: Unused; kept so existing callers keep working.
        device: Device type string passed to ``torch.autocast`` and used for
            the tokenized prompt.

    Returns:
        List of object ids without duplicates: direct matches first (in
        ``obj_ids`` order), then fallback objects by decreasing score.
    """
    if target_name in obj_types:
        all_obj_types = obj_types
        text_features = obj_type_feats
    else:
        all_obj_types = obj_types + [target_name]
        # Follow the caller-supplied device instead of assuming CUDA.
        target_text = tokenizer(["a photo of " + target_name]).to(device)
        with torch.no_grad(), torch.autocast(device):
            target_feature = clip_model.encode_text(target_text)
        text_features = torch.cat([obj_type_feats, target_feature], dim=0)

    with torch.no_grad(), torch.autocast(device):
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # .float() is a no-op for float32 and makes reduced-precision features
    # convertible to NumPy.
    text_np = text_features.float().cpu().numpy()
    # Rows are objects, columns are object types.
    sim_mat = np.dot(text_np, object_feats.T).T
    top1_type = np.argsort(sim_mat, axis=1)[:, -1]
    target_idx = all_obj_types.index(target_name)
    candidates = obj_ids[top1_type == target_idx].tolist()

    if len(candidates) < 3:
        top_indices = np.argsort(sim_mat[:, target_idx])[-3:][::-1]
        # dict.fromkeys keeps first occurrences, so direct matches stay in
        # front and nothing is listed twice.
        candidates = list(dict.fromkeys(candidates + obj_ids[top_indices].tolist()))

    return candidates


def _encode_prompts(clip_model, tokenizer, prompts, device):
    """Tokenize ``prompts`` and return their text features from ``clip_model``."""
    tokens = tokenizer(prompts).to(device)
    with torch.no_grad(), torch.autocast(device):
        return clip_model.encode_text(tokens)


def _read_stripped_lines(path):
    """Return the lines of the text file at ``path`` with whitespace stripped."""
    with open(path, "r") as f:
        return [line.strip() for line in f]


def _parse_landmarks(all_landmarks):
    """Split a landmark annotation into its target name and (name, type) pairs.

    Every line except the last is a landmark entry: the text between the
    first and second ``'.'`` is kept, its part before ``'('`` is the name and
    the part inside the parentheses is the type. The last line carries the
    target after a ``':'``; only its name (before ``'('``) is returned.
    """
    lines = all_landmarks.split('\n')
    entries = [x.split('.')[1].strip() for x in lines[:-1]]
    target = lines[-1].split(':')[1].strip().split('(')[0].strip()
    names = [x.split('(')[0].strip() for x in entries]
    types = [x.split('(')[1].split(')')[0].strip() for x in entries]
    return target, list(zip(names, types))


def main():
    """Compute target candidates for REVERIE ``val_unseen`` and save them.

    For every instruction, the landmark whose name equals the target is
    resolved to candidates: ``['floor']`` for floor landmarks, viewpoints
    from ``locate_room`` for rooms, and object ids from ``locate_object`` for
    objects. The resulting ``{scan: {inst_id: candidates}}`` mapping is
    written once per (score threshold, rank threshold, fill-to-rank)
    combination into ``../HOV-SG/node_generation_gt/target_candidates``.

    Raises:
        ValueError: If a matching landmark has a type that is neither floor,
            room nor object.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    clip_model, _, _ = open_clip.create_model_and_transforms(
        "ViT-H-14",
        pretrained="../HOV-SG/checkpoints/laion2b_s32b_b79k.bin",
        device=device,
    )
    clip_model.eval()
    tokenizer = open_clip.get_tokenizer("ViT-H-14")

    room_types = _read_stripped_lines("../HOV-SG/Navigation/mp3d_room_types.txt")
    room_type_feats = _encode_prompts(
        clip_model, tokenizer, ["a photo of a " + r for r in room_types], device
    )

    obj_types = _read_stripped_lines("../HOV-SG/Navigation/mp40_cat.txt")
    obj_type_feats = _encode_prompts(
        clip_model, tokenizer, ["a photo of " + r for r in obj_types], device
    )

    split = 'val_unseen'
    save_dir = "../HOV-SG/node_generation_gt/target_candidates"
    lm_data = load_landmark_data('REVERIE', split)

    # NOTE(review): none of the sweep parameters below are passed to
    # locate_room / locate_object, so the candidates are identical for every
    # combination. They are therefore computed once and only the output file
    # name varies -- confirm whether the thresholds were meant to be used.
    results = {}
    for scan, scan_info in tqdm(lm_data.items(), desc=f"Processing {split}", total=len(lm_data), position=0):
        results[scan] = {}

        vps, all_vp_feats = load_all_vp_feats(scan)
        obj_ids, obj_feats = load_obj_feats(scan)

        for inst_id, inst_info in tqdm(scan_info.items(), desc=f"Processing {scan}", total=len(scan_info), position=1, leave=False):
            target, landmarks = _parse_landmarks(inst_info['landmarks'])

            for landmark, landmark_type in landmarks:
                if landmark != target:
                    continue

                kind = landmark_type.lower()
                if 'floor' in kind:
                    candidates = ['floor']
                elif 'room' in kind:
                    candidates = locate_room(landmark, room_types, room_type_feats, clip_model, tokenizer, vps, all_vp_feats, device)
                elif 'object' in kind:
                    candidates = locate_object(landmark, obj_ids, obj_feats, obj_types, obj_type_feats, clip_model, tokenizer, scan, device)
                else:
                    raise ValueError(f"Unknown landmark type: {landmark_type}")

                results[scan][inst_id] = candidates

    score_thresholds = [0.0, 0.20, 0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 0.28, 0.29, 0.30]
    grid = [
        (score_threshold, rank_threshold, fill_to_rank)
        for score_threshold in score_thresholds
        for rank_threshold in range(5, 55, 5)
        for fill_to_rank in (False, True)
    ]

    os.makedirs(save_dir, exist_ok=True)
    for score_threshold, rank_threshold, fill_to_rank in tqdm(grid, desc=f"Processing {split}", total=len(grid)):
        min_candidates = rank_threshold if fill_to_rank else 1
        # round(), not int(): 0.29 * 100 evaluates to 28.999999999999996, and
        # truncating it would reuse the file name of the 0.28 threshold.
        score_tag = round(score_threshold * 100)
        save_path = os.path.join(save_dir, f"{score_tag}_rank_{rank_threshold}_min_{min_candidates}.json")
        with open(save_path, 'w') as f:
            json.dump(results, f, indent=4)
                
# Run the candidate generation only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()